// Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
// This source file is part of the Cangjie project, licensed under Apache-2.0
// with Runtime Library Exception.
//
// See https://cangjie-lang.cn/pages/LICENSE for license information.

#include "MarkPackage.h"

#include <iostream>

#include "InputFile.h"
#include "Logging.h"
#include "Package.h"
#include "SingleDeclarationSymbolVisitor.h"
#include "Universe.h"

PackageFile* input_to_output(Package* package, const InputFile* input)
{
    const auto file_name = input->path().stem().u8string();
    if (auto* file = (*package)[file_name]) {
        return file;
    }
    return new PackageFile(file_name, package);
}

PackageFile* input_to_output(Package* package, const FileLevelSymbol* symbol)
{
    auto* input_file = symbol->defining_file();
    assert(input_file);
    return input_to_output(package, input_file);
}

bool check_symbol(FileLevelSymbol* symbol)
{
    assert(symbol);
    if (const auto* named = dynamic_cast<NamedTypeSymbol*>(symbol)) {
        if (named->is(NamedTypeSymbol::Kind::SourcePrimitive))
            return false;
        if (named->is(NamedTypeSymbol::Kind::TargetPrimitive))
            return false;
    }
    return symbol->is_file_level();
}

static bool set_package(FileLevelSymbol& symbol)
{
    bool success = true;

    const auto& name = symbol.name();

    bool package_found = false;

    for (auto&& package : packages) {
        if (!package.filters()->apply(name)) {
            continue;
        }

        if (symbol.package()) {
            std::cerr << "Entity `" << name << "` is ambiguous between packages `" << symbol.package()->cangjie_name()
                      << "` and `" << package.cangjie_name() << "`" << std::endl;
            success = false;
            continue;
        }

        symbol.set_output_status(OutputStatus::Root);
        symbol.set_package_file(input_to_output(&package, &symbol));
        package_found = true;
    }

    if (!package_found && verbosity >= LogLevel::TRACE) {
        std::cerr << "Entity `" << name << "` does not match any package filter" << std::endl;
    }
    return success;
}

bool mark_roots()
{
    auto success = true;

    for (auto& member : universe.top_level()) {
        if (!set_package(member)) {
            success = false;
        }
    }

    for (auto&& type : Universe::all_declarations()) {
        // Omit primitive types, as well as types having no definition in source files
        // (those are built-ins like `id`).
        if (type.is(NamedTypeSymbol::Kind::SourcePrimitive) || type.is(NamedTypeSymbol::Kind::TargetPrimitive) ||
            !type.defining_file()) {
            continue;
        }

        if (!set_package(type)) {
            success = false;
        }
    }

    return success;
}

class SymbolReferenceCollector final : public SingleDeclarationSymbolVisitor {

public:
    [[nodiscard]] explicit SymbolReferenceCollector(FileLevelSymbol& symbol)
        : SingleDeclarationSymbolVisitor(true), symbol_(symbol)
    {
    }

    SymbolReferenceCollector(const SymbolReferenceCollector& other) = delete;

    SymbolReferenceCollector(SymbolReferenceCollector&& other) noexcept = delete;

    SymbolReferenceCollector& operator=(const SymbolReferenceCollector& other) = delete;

    SymbolReferenceCollector& operator=(SymbolReferenceCollector&& other) noexcept = delete;

    void visit()
    {
        SingleDeclarationSymbolVisitor::visit(&symbol_);
    }

protected:
    void visit_impl(FileLevelSymbol* owner, FileLevelSymbol* value, SymbolProperty, bool) override
    {
        assert(value);
        if (!owner) {
            // Skip the root type of this visit session to avoid self-referencing of each type.
            return;
        }
        const auto* named_value = dynamic_cast<const NamedTypeSymbol*>(value);
        if (named_value) {
            value = named_value->original();
        }
        if (value != &symbol_ // Self-reference
            && value->is_file_level() && value->defining_file()) {
            if (symbol_.add_reference(*value) && verbosity >= LogLevel::TRACE) {
                std::cerr << "Entity `" << symbol_.name() << "` references `" << value->name() << "`\n";
            }
        }
    }

private:
    FileLevelSymbol& symbol_;
};

class ScopeBuilderStatus final {
    bool success_ = true;
    bool changed_ = false;

public:
    [[nodiscard]] bool success() const
    {
        return success_;
    }

    [[nodiscard]] bool error() const
    {
        return !success();
    }

    [[nodiscard]] bool changed() const
    {
        return changed_;
    }

    void mark_error()
    {
        success_ = false;
    }

    void mark_changed()
    {
        changed_ = true;
    }
};

void add_all_symbol_references()
{
    for (const auto* input_directory : inputs) {
        for (const auto* input_file : *input_directory) {
            for (auto* symbol : *input_file) {
                assert(check_symbol(symbol));

                SymbolReferenceCollector(*symbol).visit();
            }
        }
    }
}

static ScopeBuilderStatus symbol_references_to_packages_pass(bool roots_only)
{
    ScopeBuilderStatus status;

    for (const auto* input_directory : inputs) {
        for (const auto* input_file : *input_directory) {
            for (auto* symbol : *input_file) {
                assert(check_symbol(symbol));

                if (symbol->output_status() != (roots_only ? OutputStatus::Root : OutputStatus::Referenced)) {
                    continue;
                }

                auto* package = symbol->package();
                assert(package);

                for (auto* reference : symbol->references_symbols()) {
                    assert(reference);
                    switch (reference->output_status()) {
                        case OutputStatus::Undefined: {
                            assert(!reference->package());
                            auto* package_file = input_to_output(package, reference->defining_file());
                            assert(package_file);
                            reference->set_output_status(OutputStatus::Referenced);
                            reference->add_referencing_package(*package);
                            reference->set_package_file(package_file);
                            status.mark_changed();
                            break;
                        }
                        case OutputStatus::Referenced:
                        case OutputStatus::ReferencedMarked: {
                            const auto* reference_package = reference->package();
                            assert(reference_package);
                            if (reference_package != package) {
                                // TODO: build graph of dependencies between packages and resolve the most
                                // common cases by selecting the closest common dependency package
                                reference->set_output_status(OutputStatus::MultiReferenced);
                                reference->add_referencing_package(*package);
                                status.mark_error();
                            }
                            break;
                        }
                        default:
                            assert(reference->package());
                            break;
                    }
                }
                if (!roots_only) {
                    assert(symbol->output_status() == OutputStatus::Referenced);
                    symbol->set_output_status(OutputStatus::ReferencedMarked);
                }
            }
        }
    }

    return status;
}

bool symbol_references_to_packages()
{
    auto status = symbol_references_to_packages_pass(true);
    auto error = status.error();
    while (status.changed()) {
        status = symbol_references_to_packages_pass(false);
        if (status.error()) {
            error = true;
        }
    }

    for (const auto* input_directory : inputs) {
        for (const auto* input_file : *input_directory) {
            for (auto* symbol : *input_file) {
                switch (symbol->output_status()) {
                    case OutputStatus::Undefined:
                        if (verbosity >= LogLevel::DEBUG) {
                            std::cerr << "Entity `" << symbol->name() << "` from `" << input_file->path().u8string()
                                      << "` is not used" << std::endl;
                        }
                        break;
                    case OutputStatus::Referenced:
                    case OutputStatus::ReferencedMarked:
                        assert(symbol->package());
                        assert(symbol->package_file());
                        if (verbosity >= LogLevel::TRACE) {
                            std::cerr << "Entity `" << symbol->name() << "` from `" << input_file->path().u8string()
                                      << "` is only used from `" << symbol->package()->cangjie_name()
                                      << "` package, assigning `" << symbol->package_file()->output_path().u8string()
                                      << '`' << std::endl;
                        }
                        break;
                    case OutputStatus::MultiReferenced:
                        std::cerr << "Entity `" << symbol->name() << "` from `" << input_file->path().u8string()
                                  << "` is ambiguous between " << symbol->number_of_referencing_packages()
                                  << " packages";
                        symbol->print_referencing_packages_info();
                        break;
                    default:
                        break;
                }
            }
        }
    }

    return !error;
}

void register_symbols_in_declaration_order()
{
    for (const auto* input_directory : inputs) {
        for (const auto* input_file : *input_directory) {
            for (auto* symbol : *input_file) {
                assert(check_symbol(symbol));

                if (auto* package_file = symbol->package_file()) {
                    package_file->add_symbol(symbol);
                }
            }
        }
    }
}

/** Given the N-dimensional VArray, get the type of its element */
static TypeLikeSymbol& get_element_type(const VArraySymbol& varray) noexcept
{
    const auto* subvarray = dynamic_cast<const VArraySymbol*>(varray.element_type_);
    return subvarray ? get_element_type(*subvarray) : *varray.element_type_;
}

/**
 * For each function parameter, if the type of the parameter is array, convert
 * it to pointer to its element
 */
static void decay_parameter_types(NonTypeSymbol& function)
{
    for (auto& parameter : function.parameters()) {
        auto* parameter_type = parameter.type();
        const auto* varray = dynamic_cast<const VArraySymbol*>(&parameter_type->canonical_type());
        if (varray) {
            parameter.set_type(&pointer(get_element_type(*varray)));
        }
    }
}

/**
 * For each function parameter, if the type of the parameter is array, convert
 * it to pointer to its element
 */
static void decay_parameter_types()
{
    for (auto& top_level : Universe::top_level()) {
        decay_parameter_types(top_level);
    }
    for (auto& type : Universe::type_definitions()) {
        for (auto& member : type.members()) {
            decay_parameter_types(member);
        }
    }
}

bool mark_package()
{
    decay_parameter_types();

    if (!mark_roots()) {
        return false;
    }

    add_all_symbol_references();

    if (!symbol_references_to_packages()) {
        return false;
    }

    register_symbols_in_declaration_order();

    return true;
}
